{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tree LSTM modeling for semantic relatedness\n",
    "\n",
    "### Sentences involving Compositional Knowledge\n",
    "This tutorial walks through training a child-sum Tree LSTM model for analyzing semantic relatedness of sentence pairs given their dependency parse trees.\n",
    "\n",
    "### Preliminaries\n",
    "This tutorial requires the latest MXNet with the new `gluon` interface. One can either build from source or install the pre-release package with `pip install --pre mxnet`. A GPU is preferred for running the complete training to match the state-of-the-art results.\n",
    "\n",
    "In addition, install `tqdm` (from the Arabic word for \"progress\") with `pip install tqdm` to show a progress meter, and install the `requests` HTTP library with `pip install requests`.\n",
    "\n",
    "\n",
    "### Inspiration\n",
    "This tutorial borrows heavily from the [PyTorch](https://github.com/dasguptar/treelstm.pytorch) example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import mxnet as mx\n",
    "from mxnet.gluon import Block, nn\n",
    "from mxnet.gluon.parameter import Parameter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class Tree(object):\n",
    "    def __init__(self, idx):\n",
    "        self.children = []\n",
    "        self.idx = idx\n",
    "\n",
    "    def __repr__(self):\n",
    "        if self.children:\n",
    "            return '{0}: {1}'.format(self.idx, str(self.children))\n",
    "        else:\n",
    "            return str(self.idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0: [1, 2: [4], 3]\n"
     ]
    }
   ],
   "source": [
    "tree = Tree(0)\n",
    "tree.children.append(Tree(1))\n",
    "tree.children.append(Tree(2))\n",
    "tree.children.append(Tree(3))\n",
    "tree.children[1].children.append(Tree(4))\n",
    "print(tree)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "###  Model\n",
    "The model is based on [child-sum tree LSTM](https://nlp.stanford.edu/pubs/tai-socher-manning-acl2015.pdf). For each sentence, the tree LSTM model extracts information following the dependency parse tree structure, and produces the sentence embedding at the root of each tree. This embedding can be used to predict semantic similarity.\n",
    "\n",
    "#### Child-sum Tree LSTM"
   ]
  },
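  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the child-sum Tree LSTM of Tai et al. updates a node $j$ with input $x_j$, children $C(j)$, and child states $(h_k, c_k)$ as:\n",
    "\n",
    "$$\\tilde{h}_j = \\sum_{k \\in C(j)} h_k$$\n",
    "$$i_j = \\sigma(W^{(i)} x_j + U^{(i)} \\tilde{h}_j + b^{(i)})$$\n",
    "$$f_{jk} = \\sigma(W^{(f)} x_j + U^{(f)} h_k + b^{(f)})$$\n",
    "$$o_j = \\sigma(W^{(o)} x_j + U^{(o)} \\tilde{h}_j + b^{(o)})$$\n",
    "$$u_j = \\tanh(W^{(u)} x_j + U^{(u)} \\tilde{h}_j + b^{(u)})$$\n",
    "$$c_j = i_j \\odot u_j + \\sum_{k \\in C(j)} f_{jk} \\odot c_k$$\n",
    "$$h_j = o_j \\odot \\tanh(c_j)$$\n",
    "\n",
    "Note that each child gets its own forget gate $f_{jk}$, computed from that child's hidden state $h_k$, while the other gates use the sum $\\tilde{h}_j$. The implementation below follows these equations."
   ]
  },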
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class ChildSumLSTMCell(Block):\n",
    "    def __init__(self, hidden_size,\n",
    "                 i2h_weight_initializer=None,\n",
    "                 hs2h_weight_initializer=None,\n",
    "                 hc2h_weight_initializer=None,\n",
    "                 i2h_bias_initializer='zeros',\n",
    "                 hs2h_bias_initializer='zeros',\n",
    "                 hc2h_bias_initializer='zeros',\n",
    "                 input_size=0, prefix=None, params=None):\n",
    "        super(ChildSumLSTMCell, self).__init__(prefix=prefix, params=params)\n",
    "        with self.name_scope():\n",
    "            self._hidden_size = hidden_size\n",
    "            self._input_size = input_size\n",
    "            self.i2h_weight = self.params.get('i2h_weight', shape=(4*hidden_size, input_size),\n",
    "                                              init=i2h_weight_initializer)\n",
    "            self.hs2h_weight = self.params.get('hs2h_weight', shape=(3*hidden_size, hidden_size),\n",
    "                                               init=hs2h_weight_initializer)\n",
    "            self.hc2h_weight = self.params.get('hc2h_weight', shape=(hidden_size, hidden_size),\n",
    "                                               init=hc2h_weight_initializer)\n",
    "            self.i2h_bias = self.params.get('i2h_bias', shape=(4*hidden_size,),\n",
    "                                            init=i2h_bias_initializer)\n",
    "            self.hs2h_bias = self.params.get('hs2h_bias', shape=(3*hidden_size,),\n",
    "                                             init=hs2h_bias_initializer)\n",
    "            self.hc2h_bias = self.params.get('hc2h_bias', shape=(hidden_size,),\n",
    "                                             init=hc2h_bias_initializer)\n",
    "\n",
    "    def forward(self, F, inputs, tree):\n",
    "        children_outputs = [self.forward(F, inputs, child)\n",
    "                            for child in tree.children]\n",
    "        if children_outputs:\n",
    "            _, children_states = zip(*children_outputs) # unzip\n",
    "        else:\n",
    "            children_states = None\n",
    "\n",
    "        with inputs.context as ctx:\n",
    "            return self.node_forward(F, F.expand_dims(inputs[tree.idx], axis=0), children_states,\n",
    "                                     self.i2h_weight.data(ctx),\n",
    "                                     self.hs2h_weight.data(ctx),\n",
    "                                     self.hc2h_weight.data(ctx),\n",
    "                                     self.i2h_bias.data(ctx),\n",
    "                                     self.hs2h_bias.data(ctx),\n",
    "                                     self.hc2h_bias.data(ctx))\n",
    "\n",
    "    def node_forward(self, F, inputs, children_states,\n",
    "                     i2h_weight, hs2h_weight, hc2h_weight,\n",
    "                     i2h_bias, hs2h_bias, hc2h_bias):\n",
    "        # comment notation:\n",
    "        # N for batch size\n",
    "        # C for hidden state dimensions\n",
    "        # K for number of children.\n",
    "\n",
    "        # FC for i, f, u, o gates (N, 4*C), from input to hidden\n",
    "        i2h = F.FullyConnected(data=inputs, weight=i2h_weight, bias=i2h_bias,\n",
    "                               num_hidden=self._hidden_size*4)\n",
    "        i2h_slices = F.split(i2h, num_outputs=4) # (N, C)*4\n",
    "        i2h_iuo = F.concat(*[i2h_slices[i] for i in [0, 2, 3]], dim=1) # (N, C*3)\n",
    "\n",
    "        if children_states:\n",
    "            # sum of children states, (N, C)\n",
    "            hs = F.add_n(*[state[0] for state in children_states])\n",
    "            # concatenation of children hidden states, (N, K, C)\n",
    "            hc = F.concat(*[F.expand_dims(state[0], axis=1) for state in children_states], dim=1)\n",
    "            # concatenation of children cell states, (N, K, C)\n",
    "            cs = F.concat(*[F.expand_dims(state[1], axis=1) for state in children_states], dim=1)\n",
    "\n",
    "            # calculate activation for forget gate. addition in f_act is done with broadcast\n",
    "            i2h_f_slice = i2h_slices[1]\n",
    "            f_act = i2h_f_slice + hc2h_bias + F.dot(hc, hc2h_weight) # (N, K, C)\n",
    "            forget_gates = F.Activation(f_act, act_type='sigmoid') # (N, K, C)\n",
    "        else:\n",
    "        # for leaf nodes, the sum of children hidden states is a zero vector.\n",
    "            hs = F.zeros_like(i2h_slices[0])\n",
    "\n",
    "        # FC for i, u, o gates, from summation of children states to hidden state\n",
    "        hs2h_iuo = F.FullyConnected(data=hs, weight=hs2h_weight, bias=hs2h_bias,\n",
    "                                    num_hidden=self._hidden_size*3)\n",
    "        i2h_iuo = i2h_iuo + hs2h_iuo\n",
    "\n",
    "        iuo_act_slices = F.split(i2h_iuo, num_outputs=3) # (N, C)*3\n",
    "        i_act, u_act, o_act = iuo_act_slices[0], iuo_act_slices[1], iuo_act_slices[2] # (N, C) each\n",
    "\n",
    "        # calculate gate outputs\n",
    "        in_gate = F.Activation(i_act, act_type='sigmoid')\n",
    "        in_transform = F.Activation(u_act, act_type='tanh')\n",
    "        out_gate = F.Activation(o_act, act_type='sigmoid')\n",
    "\n",
    "        # calculate cell state and hidden state\n",
    "        next_c = in_gate * in_transform\n",
    "        if children_states:\n",
    "            next_c = F.sum(forget_gates * cs, axis=1) + next_c\n",
    "        next_h = out_gate * F.Activation(next_c, act_type='tanh')\n",
    "\n",
    "        return next_h, [next_h, next_c]"
   ]
  },
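  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the gating structure concrete, here is a minimal pure-Python sketch of a single node update with scalar (1-dimensional) states. `childsum_node` and its scalar weight dictionaries are illustrative only, not part of the model above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "def sigmoid(z):\n",
    "    return 1.0 / (1.0 + math.exp(-z))\n",
    "\n",
    "def childsum_node(x, children, w, u, b):\n",
    "    # children: list of (h, c) pairs from child nodes; empty for leaves.\n",
    "    # w, u, b: dicts of scalar weights/biases keyed by gate name 'i', 'f', 'o', 'u'.\n",
    "    h_sum = sum(h for h, _ in children)  # sum of children hidden states\n",
    "    i = sigmoid(w['i'] * x + u['i'] * h_sum + b['i'])    # input gate\n",
    "    o = sigmoid(w['o'] * x + u['o'] * h_sum + b['o'])    # output gate\n",
    "    g = math.tanh(w['u'] * x + u['u'] * h_sum + b['u'])  # candidate update\n",
    "    # one forget gate per child, computed from that child's own hidden state\n",
    "    f = [sigmoid(w['f'] * x + u['f'] * h + b['f']) for h, _ in children]\n",
    "    c = i * g + sum(fk * ck for fk, (_, ck) in zip(f, children))\n",
    "    h = o * math.tanh(c)\n",
    "    return h, c"
   ]
  },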
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Similarity regression module"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# module for distance-angle similarity\n",
    "class Similarity(nn.Block):\n",
    "    def __init__(self, sim_hidden_size, rnn_hidden_size, num_classes):\n",
    "        super(Similarity, self).__init__()\n",
    "        with self.name_scope():\n",
    "            self.wh = nn.Dense(sim_hidden_size, in_units=2*rnn_hidden_size)\n",
    "            self.wp = nn.Dense(num_classes, in_units=sim_hidden_size)\n",
    "\n",
    "    def forward(self, F, lvec, rvec):\n",
    "        # lvec and rvec will be tree_lstm cell states at roots\n",
    "        mult_dist = F.broadcast_mul(lvec, rvec)\n",
    "        abs_dist = F.abs(lvec - rvec)\n",
    "        vec_dist = F.concat(mult_dist, abs_dist, dim=1)\n",
    "        out = F.log_softmax(self.wp(F.sigmoid(self.wh(vec_dist))))\n",
    "        return out"
   ]
  },
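  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The distance-angle features combine an elementwise product (capturing the angle between the two sentence representations) with an elementwise absolute difference (capturing their distance). Below is a plain-Python sketch of that feature construction; `similarity_features` is an illustrative name."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def similarity_features(lvec, rvec):\n",
    "    # elementwise product ('angle') and absolute difference ('distance'),\n",
    "    # concatenated along the feature axis as in the Similarity block above\n",
    "    mult_dist = [l * r for l, r in zip(lvec, rvec)]\n",
    "    abs_dist = [abs(l - r) for l, r in zip(lvec, rvec)]\n",
    "    return mult_dist + abs_dist"
   ]
  },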
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Final model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# putting the whole model together\n",
    "class SimilarityTreeLSTM(nn.Block):\n",
    "    def __init__(self, sim_hidden_size, rnn_hidden_size, embed_in_size, embed_dim, num_classes):\n",
    "        super(SimilarityTreeLSTM, self).__init__()\n",
    "        with self.name_scope():\n",
    "            self.embed = nn.Embedding(embed_in_size, embed_dim)\n",
    "            self.childsumtreelstm = ChildSumLSTMCell(rnn_hidden_size, input_size=embed_dim)\n",
    "            self.similarity = Similarity(sim_hidden_size, rnn_hidden_size, num_classes)\n",
    "\n",
    "    def forward(self, F, l_inputs, r_inputs, l_tree, r_tree):\n",
    "        l_inputs = self.embed(l_inputs)\n",
    "        r_inputs = self.embed(r_inputs)\n",
    "        # get cell states at roots\n",
    "        lstate = self.childsumtreelstm(F, l_inputs, l_tree)[1][1]\n",
    "        rstate = self.childsumtreelstm(F, r_inputs, r_tree)[1][1]\n",
    "        output = self.similarity(F, lstate, rstate)\n",
    "        return output"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dataset classes\n",
    "#### Vocab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import logging\n",
    "logging.basicConfig(level=logging.INFO)\n",
    "import numpy as np\n",
    "import random\n",
    "from tqdm import tqdm\n",
    "\n",
    "import mxnet as mx\n",
    "\n",
    "# class for vocabulary and the word embeddings\n",
    "class Vocab(object):\n",
    "    # constants for special tokens: padding, unknown, and beginning/end of sentence.\n",
    "    PAD, UNK, BOS, EOS = 0, 1, 2, 3\n",
    "    PAD_WORD, UNK_WORD, BOS_WORD, EOS_WORD = '<blank>', '<unk>', '<s>', '</s>'\n",
    "    def __init__(self, filepaths=(), embedpath=None, include_unseen=False, lower=False):\n",
    "        self.idx2tok = []\n",
    "        self.tok2idx = {}\n",
    "        self.lower = lower\n",
    "        self.include_unseen = include_unseen\n",
    "\n",
    "        self.add(Vocab.PAD_WORD)\n",
    "        self.add(Vocab.UNK_WORD)\n",
    "        self.add(Vocab.BOS_WORD)\n",
    "        self.add(Vocab.EOS_WORD)\n",
    "\n",
    "        self.embed = None\n",
    "\n",
    "        for filename in filepaths:\n",
    "            logging.info('loading %s'%filename)\n",
    "            with open(filename, 'r') as f:\n",
    "                self.load_file(f)\n",
    "        if embedpath is not None:\n",
    "            logging.info('loading %s'%embedpath)\n",
    "            with open(embedpath, 'r') as f:\n",
    "                self.load_embedding(f, reset=set([Vocab.PAD_WORD, Vocab.UNK_WORD, Vocab.BOS_WORD,\n",
    "                                                  Vocab.EOS_WORD]))\n",
    "\n",
    "    @property\n",
    "    def size(self):\n",
    "        return len(self.idx2tok)\n",
    "\n",
    "    def get_index(self, key):\n",
    "        return self.tok2idx.get(key.lower() if self.lower else key,\n",
    "                                Vocab.UNK)\n",
    "\n",
    "    def get_token(self, idx):\n",
    "        if idx < self.size:\n",
    "            return self.idx2tok[idx]\n",
    "        else:\n",
    "            return Vocab.UNK_WORD\n",
    "\n",
    "    def add(self, token):\n",
    "        token = token.lower() if self.lower else token\n",
    "        if token in self.tok2idx:\n",
    "            idx = self.tok2idx[token]\n",
    "        else:\n",
    "            idx = len(self.idx2tok)\n",
    "            self.idx2tok.append(token)\n",
    "            self.tok2idx[token] = idx\n",
    "        return idx\n",
    "\n",
    "    def to_indices(self, tokens, add_bos=False, add_eos=False):\n",
    "        vec = [Vocab.BOS] if add_bos else []\n",
    "        vec += [self.get_index(token) for token in tokens]\n",
    "        if add_eos:\n",
    "            vec.append(Vocab.EOS)\n",
    "        return vec\n",
    "\n",
    "    def to_tokens(self, indices, stop):\n",
    "        tokens = []\n",
    "        for i in indices:\n",
    "            tokens += [self.get_token(i)]\n",
    "            if i == stop:\n",
    "                break\n",
    "        return tokens\n",
    "\n",
    "    def load_file(self, f):\n",
    "        for line in f:\n",
    "            tokens = line.rstrip('\\n').split()\n",
    "            for token in tokens:\n",
    "                self.add(token)\n",
    "\n",
    "    def load_embedding(self, f, reset=[]):\n",
    "        vectors = {}\n",
    "        for line in tqdm(f.readlines(), desc='Loading embeddings'):\n",
    "            tokens = line.rstrip('\\n').split(' ')\n",
    "            word = tokens[0].lower() if self.lower else tokens[0]\n",
    "            if self.include_unseen:\n",
    "                self.add(word)\n",
    "            if word in self.tok2idx:\n",
    "                vectors[word] = [float(x) for x in tokens[1:]]\n",
    "        dim = len(next(iter(vectors.values())))\n",
    "        def to_vector(tok):\n",
    "            if tok in vectors and tok not in reset:\n",
    "                return vectors[tok]\n",
    "            elif tok not in vectors:\n",
    "                return np.random.normal(-0.05, 0.05, size=dim)\n",
    "            else:\n",
    "                return [0.0]*dim\n",
    "        self.embed = mx.nd.array([to_vector(tok) for tok in self.idx2tok])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Data iterator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Iterator class for SICK dataset\n",
    "class SICKDataIter(object):\n",
    "    def __init__(self, path, vocab, num_classes, shuffle=True):\n",
    "        super(SICKDataIter, self).__init__()\n",
    "        self.vocab = vocab\n",
    "        self.num_classes = num_classes\n",
    "        self.l_sentences = []\n",
    "        self.r_sentences = []\n",
    "        self.l_trees = []\n",
    "        self.r_trees = []\n",
    "        self.labels = []\n",
    "        self.size = 0\n",
    "        self.shuffle = shuffle\n",
    "        self.reset()\n",
    "\n",
    "    def reset(self):\n",
    "        if self.shuffle:\n",
    "            mask = list(range(self.size))\n",
    "            random.shuffle(mask)\n",
    "            self.l_sentences = [self.l_sentences[i] for i in mask]\n",
    "            self.r_sentences = [self.r_sentences[i] for i in mask]\n",
    "            self.l_trees = [self.l_trees[i] for i in mask]\n",
    "            self.r_trees = [self.r_trees[i] for i in mask]\n",
    "            self.labels = [self.labels[i] for i in mask]\n",
    "        self.index = 0\n",
    "\n",
    "    def next(self):\n",
    "        out = self[self.index]\n",
    "        self.index += 1\n",
    "        return out\n",
    "\n",
    "    def set_context(self, context):\n",
    "        self.l_sentences = [a.as_in_context(context) for a in self.l_sentences]\n",
    "        self.r_sentences = [a.as_in_context(context) for a in self.r_sentences]\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.size\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        l_tree = self.l_trees[index]\n",
    "        r_tree = self.r_trees[index]\n",
    "        l_sent = self.l_sentences[index]\n",
    "        r_sent = self.r_sentences[index]\n",
    "        label = self.labels[index]\n",
    "        return (l_tree, l_sent, r_tree, r_sent, label)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training with autograd"
   ]
  },
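  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before the full training loop, it helps to see how a scalar relatedness score maps to a class distribution and back. The training cell below defines `to_target` and `to_score` with MXNet NDArrays; here is an equivalent pure-Python sketch (`to_target_list` and `to_score_scalar` are illustrative names)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "def to_target_list(x, num_classes=5):\n",
    "    # spread a score in [1, num_classes] over its two nearest integer classes\n",
    "    target = [0.0] * num_classes\n",
    "    ceil, floor = int(math.ceil(x)), int(math.floor(x))\n",
    "    if ceil == floor:\n",
    "        target[floor - 1] = 1.0\n",
    "    else:\n",
    "        target[floor - 1] = ceil - x\n",
    "        target[ceil - 1] = x - floor\n",
    "    return target\n",
    "\n",
    "def to_score_scalar(probs):\n",
    "    # expected score under the class distribution; class i represents score i+1\n",
    "    return sum((i + 1) * p for i, p in enumerate(probs))"
   ]
  },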
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:root:==> SICK vocabulary size : 2412 \n",
      "INFO:root:==> Size of train data   : 4500 \n",
      "INFO:root:==> Size of dev data     : 500 \n",
      "INFO:root:==> Size of test data    : 4927 \n",
      "Training epoch 0: 100%|██████████| 250/250 [00:11<00:00, 21.48it/s]\n",
      "INFO:root:training acc at epoch 0: pearsonr=0.096197\n",
      "INFO:root:training acc at epoch 0: mse=1.138699\n",
      "Testing in validation mode: 100%|██████████| 500/500 [00:09<00:00, 51.57it/s]\n",
      "INFO:root:validation acc: pearsonr=0.490352\n",
      "INFO:root:validation acc: mse=1.237509\n",
      "INFO:root:New optimum found: 0.49035187610029013.\n"
     ]
    }
   ],
   "source": [
    "import argparse, pickle, math, os, random\n",
    "import logging\n",
    "logging.basicConfig(level=logging.INFO)\n",
    "import numpy as np\n",
    "\n",
    "import mxnet as mx\n",
    "from mxnet import gluon\n",
    "from mxnet.gluon import nn\n",
    "from mxnet import autograd as ag\n",
    "\n",
    "# training settings and hyper-parameters\n",
    "use_gpu = False\n",
    "optimizer = 'AdaGrad'\n",
    "seed = 123\n",
    "batch_size = 25\n",
    "training_batches_per_epoch = 10\n",
    "learning_rate = 0.01\n",
    "weight_decay = 0.0001\n",
    "epochs = 1\n",
    "rnn_hidden_size, sim_hidden_size, num_classes = 150, 50, 5\n",
    "\n",
    "# initialization\n",
    "context = [mx.gpu(0) if use_gpu else mx.cpu()]\n",
    "\n",
    "# seeding\n",
    "mx.random.seed(seed)\n",
    "np.random.seed(seed)\n",
    "random.seed(seed)\n",
    "\n",
    "# read dataset\n",
    "def verified(file_path, sha1hash):\n",
    "    import hashlib\n",
    "    sha1 = hashlib.sha1()\n",
    "    with open(file_path, 'rb') as f:\n",
    "        while True:\n",
    "            data = f.read(1048576)\n",
    "            if not data:\n",
    "                break\n",
    "            sha1.update(data)\n",
    "    matched = sha1.hexdigest() == sha1hash\n",
    "    if not matched:\n",
    "        logging.warning('Found hash mismatch in file {}, possibly due to incomplete download.'\n",
    "                     .format(file_path))\n",
    "    return matched\n",
    "\n",
    "data_file_name = 'tree_lstm_dataset-3d85a6c4.cPickle'\n",
    "data_file_hash = '3d85a6c44a335a33edc060028f91395ab0dcf601'\n",
    "if not os.path.exists(data_file_name) or not verified(data_file_name, data_file_hash):\n",
    "    from mxnet.test_utils import download\n",
    "    download('https://apache-mxnet.s3-accelerate.amazonaws.com/gluon/dataset/%s'%data_file_name,\n",
    "             overwrite=True)\n",
    "\n",
    "\n",
    "with open('tree_lstm_dataset-3d85a6c4.cPickle', 'rb') as f:\n",
    "    train_iter, dev_iter, test_iter, vocab = pickle.load(f)\n",
    "\n",
    "logging.info('==> SICK vocabulary size : %d ' % vocab.size)\n",
    "logging.info('==> Size of train data   : %d ' % len(train_iter))\n",
    "logging.info('==> Size of dev data     : %d ' % len(dev_iter))\n",
    "logging.info('==> Size of test data    : %d ' % len(test_iter))\n",
    "\n",
    "# get network\n",
    "net = SimilarityTreeLSTM(sim_hidden_size, rnn_hidden_size, vocab.size, vocab.embed.shape[1], num_classes)\n",
    "\n",
    "# use pearson correlation and mean-square error for evaluation\n",
    "metric = mx.metric.create(['pearsonr', 'mse'])\n",
    "\n",
    "# the network predicts a log-probability vector over the score classes,\n",
    "# so use the following function to convert a scalar score to such a vector,\n",
    "# e.g. 4.5 -> [0, 0, 0, 0.5, 0.5]\n",
    "def to_target(x):\n",
    "    target = np.zeros((1, num_classes))\n",
    "    ceil = int(math.ceil(x))\n",
    "    floor = int(math.floor(x))\n",
    "    if ceil==floor:\n",
    "        target[0][floor-1] = 1\n",
    "    else:\n",
    "        target[0][floor-1] = ceil - x\n",
    "        target[0][ceil-1] = x - floor\n",
    "    return mx.nd.array(target)\n",
    "\n",
    "# and use the following to convert log-probability vector to score\n",
    "def to_score(x):\n",
    "    levels = mx.nd.arange(1, 6, ctx=x.context)\n",
    "    return [mx.nd.sum(levels*mx.nd.exp(x), axis=1).reshape((-1,1))]\n",
    "\n",
    "# when evaluating in validation mode, check whether the pearson-r score has improved;\n",
    "# if so, record it as the new optimum\n",
    "def test(ctx, data_iter, best, mode='validation', num_iter=-1):\n",
    "    data_iter.reset()\n",
    "    samples = len(data_iter)\n",
    "    data_iter.set_context(ctx[0])\n",
    "    preds = []\n",
    "    labels = [mx.nd.array(data_iter.labels, ctx=ctx[0]).reshape((-1,1))]\n",
    "    for _ in tqdm(range(samples), desc='Testing in {} mode'.format(mode)):\n",
    "        l_tree, l_sent, r_tree, r_sent, label = data_iter.next()\n",
    "        z = net(mx.nd, l_sent, r_sent, l_tree, r_tree)\n",
    "        preds.append(z)\n",
    "\n",
    "    preds = to_score(mx.nd.concat(*preds, dim=0))\n",
    "    metric.update(preds, labels)\n",
    "    names, values = metric.get()\n",
    "    metric.reset()\n",
    "    for name, acc in zip(names, values):\n",
    "        logging.info(mode+' acc: %s=%f'%(name, acc))\n",
    "        if name == 'pearsonr':\n",
    "            test_r = acc\n",
    "    if mode == 'validation' and num_iter >= 0:\n",
    "        if test_r >= best:\n",
    "            best = test_r\n",
    "            logging.info('New optimum found: {}.'.format(best))\n",
    "        return best\n",
    "\n",
    "\n",
    "def train(epoch, ctx, train_data, dev_data):\n",
    "    # initialization with context\n",
    "    if isinstance(ctx, mx.Context):\n",
    "        ctx = [ctx]\n",
    "    net.collect_params().initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx[0])\n",
    "    net.embed.weight.set_data(vocab.embed.as_in_context(ctx[0]))\n",
    "    train_data.set_context(ctx[0])\n",
    "    dev_data.set_context(ctx[0])\n",
    "\n",
    "    # set up trainer for optimizing the network.\n",
    "    trainer = gluon.Trainer(net.collect_params(), optimizer, {'learning_rate': learning_rate, 'wd': weight_decay})\n",
    "\n",
    "    best_r = -1\n",
    "    Loss = gluon.loss.KLDivLoss()\n",
    "    for i in range(epoch):\n",
    "        train_data.reset()\n",
    "        num_samples = min(len(train_data), training_batches_per_epoch*batch_size)\n",
    "        # collect predictions and labels for evaluation metrics\n",
    "        preds = []\n",
    "        labels = [mx.nd.array(train_data.labels[:num_samples], ctx=ctx[0]).reshape((-1,1))]\n",
    "        for j in tqdm(range(num_samples), desc='Training epoch {}'.format(i)):\n",
    "            # get next batch\n",
    "            l_tree, l_sent, r_tree, r_sent, label = train_data.next()\n",
    "            # use autograd to record the forward calculation\n",
    "            with ag.record():\n",
    "                # forward calculation. the output is log probability\n",
    "                z = net(mx.nd, l_sent, r_sent, l_tree, r_tree)\n",
    "                # calculate loss\n",
    "                loss = Loss(z, to_target(label).as_in_context(ctx[0]))\n",
    "                # backward calculation for gradients.\n",
    "                loss.backward()\n",
    "                preds.append(z)\n",
    "            # update weight after every batch_size samples\n",
    "            if (j+1) % batch_size == 0:\n",
    "                trainer.step(batch_size)\n",
    "\n",
    "        # translate log-probability to scores, and evaluate\n",
    "        preds = to_score(mx.nd.concat(*preds, dim=0))\n",
    "        metric.update(preds, labels)\n",
    "        names, values = metric.get()\n",
    "        metric.reset()\n",
    "        for name, acc in zip(names, values):\n",
    "            logging.info('training acc at epoch %d: %s=%f'%(i, name, acc))\n",
    "        best_r = test(ctx, dev_data, best_r, num_iter=i)\n",
    "\n",
    "train(epochs, context, train_iter, dev_iter)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Conclusion\n",
    "- Gluon's imperative interface makes it straightforward to implement models with dynamic, data-dependent structure, such as Tree LSTMs."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
